{% extends 'base.attached.class' %}

{% block content %}
var {{ class_name }} = function(modelData) {
    var that = this;

    // Array-encoded binary decision tree. Every argument is an array indexed
    // by node id and is stored on the instance unchanged:
    //   lefts / rights - id of the node visited next on each branch
    //   thresholds     - value a node compares against; -2 ends the walk
    //   indices        - position in the feature vector that a node tests
    //   classes        - per-node array of numbers that is normalised and
    //                    returned (presumably per-class counts; confirm
    //                    against the exporter)
    var Tree = function(lefts, rights, thresholds, indices, classes) {
        // Threshold value the traversal interprets as "stop at this node".
        // NOTE(review): a genuine split threshold equal to -2 would also stop
        // the walk; confirm the exporter never emits one.
        var LEAF_THRESHOLD = -2;

        this.lefts = lefts;
        this.rights = rights;
        this.thresholds = thresholds;
        this.indices = indices;
        this.classes = classes;

        // Scales `weights` so that they add up to 1. When the sum is exactly 0
        // every entry becomes 1 / length instead of dividing by zero.
        var toDistribution = function(weights) {
            var count = weights.length;
            var total = 0.;
            var k;
            for (k = 0; k < count; k++) {
                total += weights[k];
            }
            var shares = [];
            for (k = 0; k < count; k++) {
                shares.push(total === 0 ? 1.0 / count : weights[k] / total);
            }
            return shares;
        };

        // Walks the tree starting at `node`: goes left when the tested feature
        // is <= the node's threshold, right otherwise, until a node carrying
        // the leaf threshold is reached. Returns that node's normalised
        // `classes` entry.
        this._compute = function(features, node) {
            var current = node;
            while (this.thresholds[current] !== LEAF_THRESHOLD) {
                var goLeft = features[this.indices[current]] <= this.thresholds[current];
                current = goLeft ? this.lefts[current] : this.rights[current];
            }
            return toDistribution(this.classes[current]);
        };
    };

    // Build one Tree per entry of `modelData`, keeping the original order.
    that.forest = [];

    for (var t = 0; t < modelData.length; t++) {
        var spec = modelData[t];
        var tree = new Tree(
          spec.lefts, spec.rights,
          spec.thresholds, spec.indices,
          spec.classes
        );
        that.forest.push(tree);
    }

    // Returns the index of the largest entry in `nums`. The first occurrence
    // wins on ties, and 0 is returned for an empty array.
    var _findMax = function(nums) {
        var best = 0;
        // Index 0 is the initial candidate, so the scan can start at 1.
        for (var k = 1; k < nums.length; k++) {
            if (nums[k] > nums[best]) {
                best = k;
            }
        }
        return best;
    };

    // Averages the distributions returned by every tree in the forest for the
    // given feature vector. The number of classes is read from the first node
    // of the first tree, so the forest must contain at least one tree.
    this._compute = function(features) {
        var trees = that.forest;
        var treeCount = trees.length;
        var classCount = trees[0].classes[0].length;
        var totals = new Array(classCount).fill(0);
        var t, c, votes;

        // Accumulate the per-class output of each tree (walk starts at node 0).
        for (t = 0; t < treeCount; t++) {
            votes = trees[t]._compute(features, 0);
            for (c = 0; c < classCount; c++) {
                totals[c] += votes[c];
            }
        }

        // Turn the sums into a mean over all trees.
        for (c = 0; c < classCount; c++) {
            totals[c] /= treeCount;
        }
        return totals;
    };

    // Returns the per-class values for one feature vector by delegating to
    // `_compute` on this estimator.
    this.predictProba = function(features) {
        var probabilities = that._compute(features);
        return probabilities;
    };

    // Returns the index of the class with the highest value reported by
    // `predictProba` for the given feature vector.
    this.predict = function(features) {
        var probabilities = that.predictProba(features);
        return _findMax(probabilities);
    };

};

var main = function () {
    if (typeof process !== 'undefined' && typeof process.argv !== 'undefined') {
        if (process.argv.length - 2 !== {{ n_features }}) {
            var IllegalArgumentException = function(message) {
                this.message = message;
                this.name = "IllegalArgumentException";
            }
            throw new IllegalArgumentException("You have to pass {{ n_features }} features.");
        }
    }

    // Features:
    var features = process.argv.slice(2);
    for (var i = 0; i < features.length; i++) {
        features[i] = parseFloat(features[i]);
    }

    // Model data:
    {{ model }}

    // Estimator:
    var clf = new {{ class_name }}(model);

    {% if is_test or to_json %}
    // Get JSON:
    console.log(JSON.stringify({
        "predict": clf.predict(features),
        "predict_proba": clf.predictProba(features)
    }));
    {% else %}
    // Get class prediction:
    var prediction = clf.predict(features);
    console.log("Predicted class: #" + prediction);

    // Get class probabilities:
    var probabilities = clf.predictProba(features);
    for (var i = 0; i < probabilities.length; i++) {
        console.log("Probability of class #" + i + " : " + probabilities[i]);
    }
    {% endif %}
}

// Run the command-line entry point only when this file is the module that
// Node.js was started with, not when it is loaded through require().
if (module === require.main) {
    main();
}
{% endblock %}